import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import genjax
from genjax import GenerativeFunction, ChoiceMap, Selection, trace
sns.set_theme(style="white")
# Pretty printing.
console = genjax.pretty(width=70)
# Reproducibility.
key = jax.random.PRNGKey(314159)

Implementing the builtin modeling language
One key property of the generative function interface is that it enables a separation between model and inference code - providing an abstraction layer that facilitates the development of modular model pieces, and inference code that abstracts over the implementation of the interface.
Now, implementing the interface on objects, and composing them in various ways (by e.g. specializing the implementation of the interface functions to support any intended composition) is a valid way to construct new generative functions. In fact, this is the pattern which generative function combinators follow - they accept generative functions as input, and produce new generative functions whose implementations are specialized to represent some specific pattern of computation.
Explicitly constructing generative functions as objects, however, can often feel unwieldy. Part of the way that GenJAX (and Gen.jl) alleviates this burden is by exposing languages which construct generative functions from programs. This drastically increases the expressivity available to the programmer.
In GenJAX, here’s an example of the BuiltinGenerativeFunction language:
@genjax.gen
def model(x):
    y = genjax.trace("y", genjax.Normal)(x, 1.0)
    z = genjax.trace("z", genjax.Normal)(y + x, 1.0)
    return z

When we apply one of the interface functions to this object, we get the associated data types that we expect.
key, tr = model.simulate(key, (1.0,))
tr

BuiltinTrace
├── gen_fn
│   └── BuiltinGenerativeFunction
│       └── source
│           └── <function model>
├── args
│   └── tuple
│       └── (const) 1.0
├── retval
│   └── f32[]
├── choices
│   └── Trie
│       ├── :y
│       │   └── DistributionTrace
│       │       ├── gen_fn
│       │       │   └── _Normal
│       │       ├── args
│       │       │   └── tuple
│       │       │       ├── (const) 1.0
│       │       │       └── (const) 1.0
│       │       ├── value
│       │       │   └── f32[]
│       │       └── score
│       │           └── f32[]
│       └── :z
│           └── DistributionTrace
│               ├── gen_fn
│               │   └── _Normal
│               ├── args
│               │   └── tuple
│               │       ├── f32[]
│               │       └── (const) 1.0
│               ├── value
│               │   └── f32[]
│               └── score
│                   └── f32[]
├── cache
│   └── Trie
└── score
    └── f32[]
How exactly do we do this? In this notebook, you’re going to find out. You’ll also get a chance to explore some of the capabilities which JAX exposes to library designers. Ideally, you’ll also get a sense of some of the limitations of JAX (and GenJAX) - which are restricted to support programs which are amenable to GPU/TPU acceleration.
The magic of JAX
Let’s examine the generative function object:
model

BuiltinGenerativeFunction
└── source
    └── <function model>
All the decorator genjax.gen does is wrap the function into this object. It holds a reference to the function we defined above.
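To make this concrete, here's a hypothetical, boiled-down analogue of what such a decorator does (the names here are illustrative - this is not GenJAX's actual implementation):

```python
# A minimal sketch of a decorator that wraps a function into an object
# holding a reference to the original (undecorated) function.
class MiniGenerativeFunction:
    def __init__(self, source):
        self.source = source  # the original Python function

    def __call__(self, *args):
        # Calling the wrapper just calls the original function.
        return self.source(*args)


def mini_gen(fn):
    return MiniGenerativeFunction(fn)


@mini_gen
def f(x):
    return x + 1.0
```

After decoration, `f` is a `MiniGenerativeFunction` object, and `f.source` is the plain Python function - exactly the handle that a tracing machinery needs.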
But clearly, we need to somehow get inside that function - because we're recording data onto the BuiltinTrace which comes from intermediate results of executing the function.
That’s where JAX comes in - JAX provides a way to trace pure, numerical Python programs - enabling us to construct program transformations which return new functions that compute different semantics from the original function.1
1 Program tracing is an approach which has its roots in automatic differentiation. If you're interested in this technique, I cannot recommend Autodidax: JAX core from scratch enough. It will introduce you to enough interesting PL ideas to keep you occupied for months, if not years.
Let’s utilize one of JAX’s interpreters to construct an intermediate representation of the function which our generative function object holds reference to:
jaxpr = jax.make_jaxpr(model.source)(1.0)
jaxpr

{ lambda ; a:f32[]. let
    b:key<fry>[] = random_seed[impl=fry] 0
    _:u32[2] = random_unwrap b
    c:f32[] = trace[addr=y gen_fn=_Normal() tree_in=PyTreeDef((*, *))] a 1.0
    d:f32[] = convert_element_type[new_dtype=float32 weak_type=False] a
    e:f32[] = add c d
    f:key<fry>[] = random_seed[impl=fry] 0
    _:u32[2] = random_unwrap f
    g:f32[] = trace[addr=z gen_fn=_Normal() tree_in=PyTreeDef((*, *))] e 1.0
  in (g,) }
So jax.make_jaxpr takes a function f :: A -> B and returns a function f' :: A -> Jaxpr, where Jaxpr is the program representation above.
When we run this function using Python's interpreter, JAX lifts the inputs to instances of a class called Tracer. JAX keeps an internal stack of interpreters which redirect infix operations on Tracer instances and modify their behavior. Additionally, JAX exposes new primitives (like all the NumPy primitives) which wrap a function called bind. bind takes in Tracer arguments, looks through them (and the interpreter stack), and selects the interpreter which should handle the call - then that interpreter is allowed to process_primitive, invoking the semantics which the interpreter defines for that primitive.
jax.make_jaxpr uses the above process to walk the program, and construct the above intermediate representation.
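You can see the Tracer lifting directly with a tiny (illustrative) experiment - the argument our function receives during jax.make_jaxpr is not a Python float at all:

```python
import jax

seen = []


def g(x):
    # During tracing, `x` is not a Python float - it's a Tracer
    # instance that records the operations applied to it.
    seen.append(type(x).__name__)
    return x * 2.0


_ = jax.make_jaxpr(g)(1.0)
# `seen[0]` is the name of a Tracer subclass (e.g. "DynamicJaxprTracer"),
# even though we called the staged function with the float 1.0.
```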
Now, the point of having this representation is that we can transform it further! We can lower it to other representations (including things like XLA - the linear algebra accelerator that JAX utilizes to go high performance). We could also write another interpreter which walks this representation, invokes other primitives with bind, etc - deferring further transformation to the next interpreter in line.
This admittedly rough description is the secret behind JAX's compositional transformations.
New semantics via program transformations
Let’s examine the representation once more.
jaxpr

{ lambda ; a:f32[]. let
    b:key<fry>[] = random_seed[impl=fry] 0
    _:u32[2] = random_unwrap b
    c:f32[] = trace[addr=y gen_fn=_Normal() tree_in=PyTreeDef((*, *))] a 1.0
    d:f32[] = convert_element_type[new_dtype=float32 weak_type=False] a
    e:f32[] = add c d
    f:key<fry>[] = random_seed[impl=fry] 0
    _:u32[2] = random_unwrap f
    g:f32[] = trace[addr=z gen_fn=_Normal() tree_in=PyTreeDef((*, *))] e 1.0
  in (g,) }
You’ll notice that there is an intrinsic called trace here - which looks suspiciously similar to genjax.trace above.
trace is a custom primitive that GenJAX defines - by defining a new primitive, we can place a stub in the intermediate representation, which we can further transform to implement the semantics we wish to express.
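Here's a hedged sketch of the idea - not GenJAX's actual trace primitive, but a toy primitive (all names are illustrative) that is semantically the identity function, yet leaves its own marker equation in the Jaxpr:

```python
import jax

# `Primitive` moved to `jax.extend.core` in newer JAX releases;
# fall back to `jax.core` on older ones.
try:
    from jax.extend.core import Primitive
except ImportError:
    from jax.core import Primitive

# A toy stand-in for a stub intrinsic: identity semantics, but it
# appears as its own equation in the staged representation.
stub_p = Primitive("my_stub")
stub_p.def_impl(lambda x, **params: x)           # concrete evaluation rule
stub_p.def_abstract_eval(lambda x, **params: x)  # shape/dtype rule


def stub(x, addr):
    # `addr` is recorded as a primitive parameter, much like the
    # `addr=y` parameter on `trace` in the Jaxpr above.
    return stub_p.bind(x, addr=addr)


jaxpr = jax.make_jaxpr(lambda x: stub(x, addr="y") + 1.0)(1.0)
```

Printing `jaxpr` shows a `my_stub[addr=y]` equation - a stub which a later interpreter can intercept and give whatever semantics it wishes, which is exactly the trick trace plays.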
A high level view
Now, we need to transform it! Here’s where some serious design decisions enter into the picture.
One thing you might notice about the Jaxpr above is that the arity of the function is fixed, and so is the arity of the return value. But when we call simulate on our model, we get out something which is not an h :: f32[] (it's actually a jax.Pytree with a lot more data) - so we'd expect a lot more return values in the Jaxpr2.
2 JAX flattens/unflattens Pytree instances on each side of the IR boundary - the IR is strongly typed, but only natively supports a few base types, and a few composite array types.
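The flatten/unflatten round trip from the footnote is easy to see directly with jax.tree_util:

```python
import jax.tree_util as jtu

# Flatten: a nested Pytree becomes a flat list of leaves plus a treedef
# describing the container structure.
pytree = {"a": 1.0, "b": (2.0, 3.0)}
leaves, treedef = jtu.tree_flatten(pytree)
# leaves is the flat list [1.0, 2.0, 3.0]

# Unflatten: the treedef reassembles the leaves into the original shape.
rebuilt = jtu.tree_unflatten(treedef, leaves)
```

Only the flat leaves cross the IR boundary; the treedef is static metadata that reassembles the structure on the other side.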
What gives?
Here’s where JAX’s support for compositional application of interpreters comes into play.
Instead of attempting to modify the IR above to change the arity of everything (a process which the authors expect would be quite painful and bug-prone), we can write another interpreter which walks the IR and evaluates it - but that interpreter can keep track of the state that we want to put into the BuiltinTrace at the end of the interface invocation.
Then, we can stage out that interpreter to support JIT compilation, etc. I’ll describe the process below in pseudo-types:
We start with f :: A -> B, and we stage it to get a new function f' :: Type[A] -> Jaxpr, then we write an interpreter I with signature I :: (Jaxpr, A) -> (B, State). The application of I itself can also be staged.
So this is really nice - we don’t have to munge the IR manually, we just get to write an interpreter to do the transformation for us. That’s the power that JAX provides for us!
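To make the pattern concrete, here's a minimal Jaxpr evaluator - this is not GenJAX's propagate, just a sketch of the standard environment-walking loop that such interpreters build on (the real interpreter additionally threads handler state through the walk):

```python
import jax

# `Literal` moved to `jax.extend.core` in newer JAX releases.
try:
    from jax.extend.core import Literal
except ImportError:
    from jax.core import Literal


def eval_jaxpr(jaxpr, consts, *args):
    # Environment mapping Jaxpr variables to runtime values.
    env = {}

    def read(v):
        return v.val if isinstance(v, Literal) else env[v]

    def write(v, val):
        env[v] = val

    # Bind constants and inputs.
    for var, const in zip(jaxpr.constvars, consts):
        write(var, const)
    for var, arg in zip(jaxpr.invars, args):
        write(var, arg)

    # Walk the equations in order, re-binding each primitive.
    # (A transforming interpreter would intercept primitives here
    # and substitute its own semantics.)
    for eqn in jaxpr.eqns:
        invals = [read(v) for v in eqn.invars]
        outvals = eqn.primitive.bind(*invals, **eqn.params)
        if not eqn.primitive.multiple_results:
            outvals = [outvals]
        for var, val in zip(eqn.outvars, outvals):
            write(var, val)

    return [read(v) for v in jaxpr.outvars]


closed = jax.make_jaxpr(lambda x: x * 2.0 + 1.0)(3.0)
result = eval_jaxpr(closed.jaxpr, closed.literals, 3.0)
```

Because the loop itself only calls bind, the whole evaluator is JAX traceable - meaning the interpreter's application can itself be staged and jitted, exactly as the pseudo-types above suggest.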
Interpreter design decisions
With the high-level view in mind, we’ll examine two of the interface implementations. The first is simulate - likely the easiest implementation to understand3. The second is update.
3 For this notebook, we’re going to ignore the inference math that we wish to support!
Now, in GenJAX, the interpreter is written to be re-usable for each of the interface functions. Because we've chosen to re-use the interpreter (and parametrize the transformation semantics by configuring the interpreter in other ways - besides the implementation), you're going to see some complexity right out of the gate.
This complexity exists because we wish to expose incremental computation optimizations in update. To support this customization, the interpreter can best be described as a propagation interpreter - similar to Julia's abstract interpretation machinery (if you're familiar). A propagation interpreter treats the Jaxpr as an undirected graph - and performs interpretation by iterating until a fixpoint condition is satisfied.
The high level pattern from the previous section is still true! But if you’ve written interpreters for something like Structure and Interpretation of Computer Programs before - this interpreter might be a slight shock to the system.
Here’s a boiled down form of the simulate_transform:
def simulate_transform(f, **kwargs):
    def _inner(key, args):
        # Step 1: stage out the function to a `Jaxpr`.
        closed_jaxpr, (flat_args, in_tree, out_tree) = stage(f)(
            key, *args, **kwargs
        )
        jaxpr, consts = closed_jaxpr.jaxpr, closed_jaxpr.literals
        # Step 2: create a `Simulate` instance, which we parametrize
        # the propagation interpreter with.
        #
        # `Bare` is an instance of something called a `Cell` - the
        # objects which the propagation interpreter reasons about.
        handler = Simulate()
        final_env, ret_state = propagate(
            Bare,
            bare_propagation_rules,
            jaxpr,
            [Bare.new(v) for v in consts],
            list(map(Bare.new, flat_args)),
            [Bare.unknown(var.aval) for var in jaxpr.outvars],
            handler=handler,
        )
        # Step 3: when the interpreter finishes, we read the values
        # out of its environment.
        flat_out = safe_map(final_env.read, jaxpr.outvars)
        flat_out = map(lambda v: v.get_val(), flat_out)
        key_and_returns = jtu.tree_unflatten(out_tree, flat_out)
        key, *retvals = key_and_returns
        retvals = tuple(retvals)
        # Here's the handler state - remember the signature from
        # above `I :: (Jaxpr, A) -> (B, State)`, these fields
        # below are the `State`.
        score = handler.score
        chm = handler.choice_state
        cache = handler.cache_state
        # This returns all the things which we want to put
        # into `BuiltinTrace`.
        return key, (f, args, retvals, chm, score), cache

    return _inner

And, just to show you that this is the key behind how we implement simulate, I've copied the BuiltinGenerativeFunction class method for simulate below:
def simulate(self, key, args, **kwargs):
    assert isinstance(args, Tuple)
    key, (f, args, r, chm, score), cache = simulate_transform(
        self.source, **kwargs
    )(key, args)
    return key, BuiltinTrace(self, args, r, chm, cache, score)

We'll discuss propagate in a moment - but a few high-level things.
Note that the simulate method can be staged out / used with JAX’s interfaces:
jitted = jax.jit(model.simulate)
key, tr = jitted(key, (1.0,))
tr

BuiltinTrace
├── gen_fn
│   └── BuiltinGenerativeFunction
│       └── source
│           └── <function model>
├── args
│   └── tuple
│       └── f32[]
├── retval
│   └── f32[]
├── choices
│   └── Trie
│       ├── :y
│       │   └── DistributionTrace
│       │       ├── gen_fn
│       │       │   └── _Normal
│       │       ├── args
│       │       │   └── tuple
│       │       │       ├── f32[]
│       │       │       └── f32[]
│       │       ├── value
│       │       │   └── f32[]
│       │       └── score
│       │           └── f32[]
│       └── :z
│           └── DistributionTrace
│               ├── gen_fn
│               │   └── _Normal
│               ├── args
│               │   └── tuple
│               │       ├── f32[]
│               │       └── f32[]
│               ├── value
│               │   └── f32[]
│               └── score
│                   └── f32[]
├── cache
│   └── Trie
└── score
    └── f32[]
That’s because simulate_transform and the interpreter implementation itself for propagate are all JAX traceable.
The only difference between the BuiltinTrace which we first generated at the top of the notebook and this one is that jax.jit will lift the 1.0 argument to a Tracer type, versus the non-jitted interpreter which just uses the Python float value.
And again, we can also stage out our simulate implementation and get a Jaxpr back:
jax.make_jaxpr(model.simulate)(key, (1.0,))

{ lambda ; a:u32[2] b:f32[]. let
    c:key<fry>[] = random_seed[impl=fry] 0
    _:u32[2] = random_unwrap c
    d:key<fry>[] = random_seed[impl=fry] 0
    _:u32[2] = random_unwrap d
    e:key<fry>[] = random_wrap[impl=fry] a
    f:key<fry>[2] = random_split[count=2] e
    g:u32[2,2] = random_unwrap f
    h:u32[1,2] = slice[limit_indices=(1, 2) start_indices=(0, 0) strides=(1, 1)] g
    i:u32[2] = squeeze[dimensions=(0,)] h
    j:u32[1,2] = slice[limit_indices=(2, 2) start_indices=(1, 0) strides=(1, 1)] g
    k:u32[2] = squeeze[dimensions=(0,)] j
    l:key<fry>[] = random_wrap[impl=fry] k
    m:u32[] = random_bits[bit_width=32 shape=()] l
    n:u32[] = shift_right_logical m 9
    o:u32[] = or n 1065353216
    p:f32[] = bitcast_convert_type[new_dtype=float32] o
    q:f32[] = sub p 1.0
    r:f32[] = sub 1.0 -0.9999999403953552
    s:f32[] = mul q r
    t:f32[] = add s -0.9999999403953552
    u:f32[] = reshape[dimensions=None new_sizes=()] t
    v:f32[] = max -0.9999999403953552 u
    w:f32[] = erf_inv v
    x:f32[] = mul 1.4142135381698608 w
    y:f32[] = mul 1.0 x
    z:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
    ba:f32[] = add z y
    bb:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
    bc:f32[] = sub ba bb
    bd:f32[] = div bc 1.0
    be:f32[] = abs bd
    bf:f32[] = integer_pow[y=2] be
    bg:f32[] = log 6.283185307179586
    bh:f32[] = convert_element_type[new_dtype=float32 weak_type=False] bg
    bi:f32[] = add bf bh
    bj:f32[] = mul -1.0 bi
    bk:f32[] = log 1.0
    bl:f32[] = sub 2.0 bk
    bm:f32[] = convert_element_type[new_dtype=float32 weak_type=False] bl
    bn:f32[] = div bj bm
    bo:f32[] = reduce_sum[axes=()] bn
    bp:f32[] = add 0.0 bo
    bq:f32[] = add ba b
    br:key<fry>[] = random_wrap[impl=fry] i
    bs:key<fry>[2] = random_split[count=2] br
    bt:u32[2,2] = random_unwrap bs
    bu:u32[1,2] = slice[limit_indices=(1, 2) start_indices=(0, 0) strides=(1, 1)] bt
    bv:u32[2] = squeeze[dimensions=(0,)] bu
    bw:u32[1,2] = slice[limit_indices=(2, 2) start_indices=(1, 0) strides=(1, 1)] bt
    bx:u32[2] = squeeze[dimensions=(0,)] bw
    by:key<fry>[] = random_wrap[impl=fry] bx
    bz:u32[] = random_bits[bit_width=32 shape=()] by
    ca:u32[] = shift_right_logical bz 9
    cb:u32[] = or ca 1065353216
    cc:f32[] = bitcast_convert_type[new_dtype=float32] cb
    cd:f32[] = sub cc 1.0
    ce:f32[] = sub 1.0 -0.9999999403953552
    cf:f32[] = mul cd ce
    cg:f32[] = add cf -0.9999999403953552
    ch:f32[] = reshape[dimensions=None new_sizes=()] cg
    ci:f32[] = max -0.9999999403953552 ch
    cj:f32[] = erf_inv ci
    ck:f32[] = mul 1.4142135381698608 cj
    cl:f32[] = mul 1.0 ck
    cm:f32[] = add bq cl
    cn:f32[] = sub cm bq
    co:f32[] = div cn 1.0
    cp:f32[] = abs co
    cq:f32[] = integer_pow[y=2] cp
    cr:f32[] = log 6.283185307179586
    cs:f32[] = convert_element_type[new_dtype=float32 weak_type=False] cr
    ct:f32[] = add cq cs
    cu:f32[] = mul -1.0 ct
    cv:f32[] = log 1.0
    cw:f32[] = sub 2.0 cv
    cx:f32[] = convert_element_type[new_dtype=float32 weak_type=False] cw
    cy:f32[] = div cu cx
    cz:f32[] = reduce_sum[axes=()] cy
    da:f32[] = add bp cz
  in (bv, b, cm, b, 1.0, ba, bo, bq, 1.0, cm, cz, da) }
This gives us our pure, array math code. You have to admit that's pretty elegant!
How does propagate work?
Now, in this section - we’re going to talk about the nitty gritty of propagate itself. What exactly is this interpreter doing? Let’s examine the context surrounding the call to propagate:
def simulate_transform(f, **kwargs):
    def _inner(key, args):
        closed_jaxpr, (flat_args, in_tree, out_tree) = stage(f)(
            key, *args, **kwargs
        )
        jaxpr, consts = closed_jaxpr.jaxpr, closed_jaxpr.literals
        handler = Simulate()
        final_env, ret_state = propagate(
            # A lattice type
            Bare,
            # Lattice propagation rules
            bare_propagation_rules,
            # The Jaxpr which we wish to interpret
            jaxpr,
            # Trace-time constants
            [Bare.new(v) for v in consts],
            # Input cells
            list(map(Bare.new, flat_args)),
            # Output cells
            [Bare.unknown(var.aval) for var in jaxpr.outvars],
            # How we handle `trace`.
            handler=handler,
        )
        ...

    return _inner

First, we stage our model function into a Jaxpr - when we perform the staging process, everything (e.g. custom datatypes which are Pytree implementors) gets flattened out to array leaves.
After we stage, we collect all the data which we want to use to initialize our interpreter's environment - but here we encounter our first bit of complexity.
What is Bare? And what is a Cell? Let’s start with the latter question: a Cell is an abstract type which represents a lattice value.
To understand what a lattice value is, it's worth gaining a high-level picture of what propagate attempts to do. propagate is an interpreter based on mixed concrete/abstract interpretation - it treats the Jaxpr as a graph, where the operations are nodes and the SSA values (e.g. the named registers like ci, cj, etc.) are edges.
The interpreter will iterate over the graph - attempting to update information about the edges by applying propagation rules (hence the name, propagate) which we define (bare_propagation_rules, above).
A propagation rule accepts a list of input cells (the SSA edges which flow into the operation) and a list of output cells. It returns a new modified list of input cells, and a new modified list of output cells, as well as a state value (in this notebook, we won’t discuss the state value - it’s unneeded for the interfaces we will describe).
The interpreter keeps a queue of nodes and an environment which maps SSA values to lattice values. We pop a node off the queue, grab the existing lattice values for its input and output SSA values, attempt to update them using a propagation rule, and store the update in the environment. After we attempt to update the cells, we determine whether the update has changed the information level of any of them. If the information level has changed for any cell (as measured using the partial order on lattice values), we add any nodes which the SSA value associated with that cell flows into back onto the queue.
This process describes an iterative algorithm which attempts to compute an information fixpoint - defined by a state transition function (which operates on the state of all cells in the Jaxpr - the environment) which we get to customize using propagation rules.
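The worklist loop just described can be sketched in a few lines. This toy version (the names and structure are illustrative, not GenJAX's) propagates concrete values through a list of equations until nothing changes - note that the equation for d fires only after the equation for c raises c's information level and re-enqueues it:

```python
# A toy worklist fixpoint: each cell is either unknown (None) or a known
# value; a rule fires when all of an equation's inputs are known.
def propagate_toy(eqns, env):
    """eqns: list of (fn, input_names, output_name); env: name -> value or None."""
    queue = list(eqns)
    while queue:
        fn, in_names, out_name = queue.pop(0)
        ins = [env[n] for n in in_names]
        if all(v is not None for v in ins) and env[out_name] is None:
            # The information level of `out_name` increases: unknown -> known.
            env[out_name] = fn(*ins)
            # Re-enqueue every equation that consumes the updated cell.
            queue.extend(e for e in eqns if out_name in e[1])
    return env


env = {"a": 1.0, "b": 2.0, "c": None, "d": None}
eqns = [
    (lambda c: c * c, ["c"], "d"),          # d = c * c (blocked until c is known)
    (lambda a, b: a + b, ["a", "b"], "c"),  # c = a + b
]
env = propagate_toy(eqns, env)
```

The first pass over the d-equation does nothing because c is still unknown; once the c-equation fires, the d-equation is re-enqueued and the loop reaches its fixpoint with c = 3.0 and d = 9.0.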
I’m not going to inline any of the implementation of this interpreter into this notebook. I’ll refer the reader to the implementation of the interpreter.4
4 Note that the ideas behind this interpreter are quite widespread - but the original implementation (which the GenJAX authors modified) came from Oryx, and that implementation initially came from Roy Frostig (as far as we can tell).
What happens in simulate?
Great - so how do we utilize this interpreter idea to implement the simulate_transform described above?